Gaussian Process Classification
Preliminary steps
Loading necessary packages
using Plots
using HTTP, CSV
using DataFrames: DataFrame
using AugmentedGaussianProcesses
Loading the banana dataset from OpenML
# Download the dataset as CSV from OpenML.
data = HTTP.get("https://www.openml.org/data/get_csv/1586217/phpwRjVjk")
# Parse the response body into a DataFrame.
data = CSV.read(data.body, DataFrame)
# Relabel every entry of the `Class` column equal to 2 as -1.
data.Class[data.Class .== 2] .= -1
# Switch to a plain matrix; its first two columns are used as the inputs `X`.
data = Matrix(data)
X = data[:, 1:2]
Y = Int.(data[:, end]);
We create a function to visualize the data
"""
    plot_data(X, Y; size=(300, 500))

Draw a scatter plot in which every column of `X` is passed to `Plots.scatter` as one
coordinate and the markers are grouped by `Y`.

# Keywords
- `size`: forwarded unchanged to `Plots.scatter`.

Returns the plot object created by `Plots.scatter`.
"""
function plot_data(X, Y; size=(300, 500))
    coordinates = eachcol(X)
    return Plots.scatter(
        coordinates...; group=Y, alpha=0.2, markerstrokewidth=0.0, lab="", size=size
    )
end
# Show the dataset, overriding the default figure size with 500×500.
plot_data(X, Y; size=(500, 500))
Run sparse classification with increasing number of inducing points
# Numbers of inducing points to try, one SVGP model per entry.
Ms = [4, 8, 16, 32, 64]
# One slot per sparse model plus a final slot reserved for the full model.
models = Vector{AbstractGPModel}(undef, length(Ms) + 1)
# The same kernel object is shared by every model trained below.
kernel = SqExponentialKernel() ∘ ScaleTransform(1.0)
for (i, num_inducing) in enumerate(Ms)
    @info "Training with $(num_inducing) points"
    # Inducing points are obtained from `X` via `KmeansAlg(num_inducing)`.
    inducing = inducingpoints(KmeansAlg(num_inducing), X)
    # NOTE(review): `optimiser=false` / `Zoptimiser=false` presumably switch off
    # hyperparameter and inducing-point optimisation — confirm in the AGP docs.
    global m = SVGP(
        kernel,
        LogisticLikelihood(),
        AnalyticVI(),
        inducing;
        optimiser=false,
        Zoptimiser=false,
    )
    # NOTE(review): the trailing 20 is presumably the iteration count — confirm.
    @time train!(m, X, Y, 20)
    models[i] = m
end
[ Info: Training with 4 points
0.008539 seconds (3.00 k allocations: 16.952 MiB)
[ Info: Training with 8 points
0.010972 seconds (3.02 k allocations: 27.585 MiB)
[ Info: Training with 16 points
0.019171 seconds (3.04 k allocations: 48.981 MiB)
[ Info: Training with 32 points
0.034849 seconds (3.09 k allocations: 92.321 MiB)
[ Info: Training with 64 points
0.166126 seconds (3.37 k allocations: 181.211 MiB, 51.52% gc time)
Running the full model
@info "Running full model"
# Full model: a VGP built directly on `X`, `Y` with the same kernel, likelihood and
# inference objects as the SVGP models above.
mfull = VGP(X, Y, kernel, LogisticLikelihood(), AnalyticVI(); optimiser=false)
# NOTE(review): the trailing 5 is presumably the iteration count — confirm in the AGP docs.
@time train!(mfull, 5)
models[end] = mfull
Variational Gaussian Process with a Bernoulli Likelihood with Logistic Link infered by Analytic Variational Inference
We create a prediction and plot function on a grid
"""
    compute_grid(model, n_grid=50)

Evaluate `proba_y(model, ·)` on an `n_grid × n_grid` grid spanning the fixed box
`[-3.25, 3.65] × [-2.85, 3.4]`.

Returns `(y_grid, x_lin, y_lin)`, where `y_grid` is the first output of `proba_y` at the
flattened grid points (first coordinate varying fastest) and `x_lin`, `y_lin` are the
coordinate ranges along each axis.
"""
function compute_grid(model, n_grid=50)
    lower = [-3.25, -2.85]
    upper = [3.65, 3.4]
    x_lin = range(lower[1], upper[1]; length=n_grid)
    y_lin = range(lower[2], upper[2]; length=n_grid)
    # Flatten the grid into a vector of 2-element points, x varying fastest.
    points = vec([[x, y] for x in x_lin, y in y_lin])
    y_grid, _ = proba_y(model, points)
    return y_grid, x_lin, y_lin
end
"""
    plot_model(model, X, Y, title=nothing; size=(300, 500))

Plot the data with `plot_data` and overlay, in black, the 0.5 level line of the grid
values returned by `compute_grid(model, 50)`.

When `title` is `nothing`, the title is `"M = <AGP.dim(model[1])>"` for an `SVGP` model and
`"full"` otherwise. For an `SVGP` model the points returned by `AGP.Zview(model[1])` are
also scattered on top (presumably the inducing point locations — confirm in the AGP docs).

Returns the plot object.
"""
function plot_model(model, X, Y, title=nothing; size=(300, 500))
    n_grid = 50
    y_pred, x_lin, y_lin = compute_grid(model, n_grid)
    is_sparse = model isa SVGP
    if isnothing(title)
        title = is_sparse ? "M = $(AGP.dim(model[1]))" : "full"
    end
    p = plot_data(X, Y; size=size)
    # Transposed so that rows follow `y_lin` and columns follow `x_lin`.
    surface = reshape(y_pred, n_grid, n_grid)'
    Plots.contour!(
        p,
        x_lin,
        y_lin,
        surface;
        cbar=false,
        levels=[0.5],
        fill=false,
        color=:black,
        linewidth=2.0,
        title=title,
    )
    if is_sparse
        Z = hcat(AGP.Zview(model[1])...)
        Plots.scatter!(p, eachrow(Z)...; msize=2.0, color="black", lab="")
    end
    return p
end;
Now run the prediction for every model and visualize the differences
# One panel per trained model, arranged side by side in a single row.
Plots.plot(
    map(model -> plot_model(model, X, Y), models)...;
    layout=(1, length(models)),
    size=(1000, 200),
)
Bayesian SVM vs Logistic
We now create a model with the Bayesian SVM likelihood
# VGP on the same data and kernel, this time with the `BayesianSVM()` likelihood.
mbsvm = VGP(X, Y, kernel, BayesianSVM(), AnalyticVI(); optimiser=false)
@time train!(mbsvm, 5)
(Variational Gaussian Process with a Bayesian SVM infered by Analytic Variational Inference , (local_vars = (ω = [0.2670231563482105, 0.3672485011060839, 0.008274352900861275, 0.12202954702121552, 0.24779122630057543, 5.302032576181703, 1.5984465923945335, 1.9596043324856234, 0.6847119281767285, 0.13062766416421334 … 0.04043278780198051, 0.006683983485350634, 1.858697916828438, 1.6652730382811154, 0.01392869774133556, 4.882294630899031, 1.8932642381292637, 2.2637571964758494, 0.5525404030710835, 5.125627317160498], θ = [1.9351985832592284, 1.6501369305279572, 10.993424045976674, 2.8626450424169585, 2.0088940733678022, 0.4342889747621343, 0.7909534686364933, 0.7143578220535454, 1.208498326021844, 2.766829640149746 … 4.973168300220936, 12.23157312148591, 0.7334923587796035, 0.7749207224517213, 8.473147037622837, 0.45257234041463684, 0.7267656360611252, 0.6646378623124665, 1.3452963996556095, 0.44169907240383804]), opt_state = (NamedTuple(),), hyperopt_state = (NamedTuple(),), kernel_matrices = ((K = LinearAlgebra.Cholesky{Float64, Matrix{Float64}}([1.0000499987500624 0.017000746952647305 … 0.6304203262759929 0.377775755997342; 0.017001596968745064 0.9999054828347788 … 0.0007664332243453772 0.26392364754823905; … ; 0.6304518465043205 0.011483977224073186 … 0.01002688609649221 3.4159402959704765e-6; 0.3777946443129457 0.27032117226579383 … 0.38985281832414914 0.010036620197459574], 'U', 0),),)))
And compare it with the Logistic likelihood
# Two panels side by side: the full logistic model and the Bayesian SVM model.
Plots.plot(
    [
        plot_model(model, X, Y, label; size=(500, 500)) for
        (model, label) in zip([models[end], mbsvm], ["Logistic", "BSVM"])
    ]...;
    layout=(1, 2),
)
This page was generated using Literate.jl.